{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---+---------+---------+---+\n",
      "|age|education|wantsMore|  y|\n",
      "+---+---------+---------+---+\n",
      "|<25|      low|      yes|  0|\n",
      "|<25|      low|      yes|  0|\n",
      "|<25|      low|      yes|  0|\n",
      "|<25|      low|      yes|  0|\n",
      "|<25|      low|      yes|  0|\n",
      "+---+---------+---------+---+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "cuse = spark.read.csv('data/cuse_binary.csv', header=True, inferSchema=True)\n",
    "cuse.show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 254,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {Row(age=u'25-29'): 404,\n",
       "             Row(age=u'30-39'): 612,\n",
       "             Row(age=u'40-49'): 194,\n",
       "             Row(age=u'<25'): 397})"
      ]
     },
     "execution_count": 254,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cuse.columns[0:3]\n",
    "# cuse.select('age').distinct().show()\n",
    "cuse.select('age').rdd.countByValue()\n",
    "# cuse.select('education').rdd.countByValue()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 256,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+-----------+\n",
      "|  age|indexed_age|\n",
      "+-----+-----------+\n",
      "|30-39|        0.0|\n",
      "|  <25|        2.0|\n",
      "|25-29|        1.0|\n",
      "|40-49|        3.0|\n",
      "+-----+-----------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# string index each categorical string columns\n",
    "from pyspark.ml.feature import StringIndexer\n",
    "from pyspark.ml import Pipeline\n",
    "indexers = [StringIndexer(inputCol=column, outputCol=\"indexed_\"+column) for column in ('age', 'education', 'wantsMore')]\n",
    "pipeline = Pipeline(stages=indexers)\n",
    "indexed_cuse = pipeline.fit(cuse).transform(cuse)\n",
    "indexed_cuse.select('age', 'indexed_age').distinct().show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 283,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+----------------------+----------------------+---+\n",
      "|onehotencode_age|onehotencode_education|onehotencode_wantsMore|  y|\n",
      "+----------------+----------------------+----------------------+---+\n",
      "|   (3,[1],[1.0])|             (1,[],[])|         (1,[0],[1.0])|  0|\n",
      "|   (3,[2],[1.0])|         (1,[0],[1.0])|             (1,[],[])|  1|\n",
      "|   (3,[0],[1.0])|         (1,[0],[1.0])|         (1,[0],[1.0])|  0|\n",
      "|       (3,[],[])|         (1,[0],[1.0])|         (1,[0],[1.0])|  1|\n",
      "|   (3,[2],[1.0])|             (1,[],[])|         (1,[0],[1.0])|  0|\n",
      "+----------------+----------------------+----------------------+---+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# onehotencode each indexed categorical columns\n",
    "from pyspark.ml.feature import OneHotEncoder\n",
    "columns = indexed_cuse.columns[0:3]\n",
    "onehoteencoders = [OneHotEncoder(inputCol=\"indexed_\"+column, outputCol=\"onehotencode_\"+column) for column in columns]\n",
    "pipeline = Pipeline(stages=onehoteencoders)\n",
    "onehotencode_columns = ['onehotencode_age', 'onehotencode_education', 'onehotencode_wantsMore', 'y']\n",
    "onehotencode_cuse = pipeline.fit(indexed_cuse).transform(indexed_cuse).select(onehotencode_columns)\n",
    "onehotencode_cuse.distinct().show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 206,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+----------------------+----------------------+-----+-------------------+\n",
      "|onehotencode_age|onehotencode_education|onehotencode_wantsMore|label|           features|\n",
      "+----------------+----------------------+----------------------+-----+-------------------+\n",
      "|   (3,[2],[1.0])|             (1,[],[])|         (1,[0],[1.0])|    0|(5,[2,4],[1.0,1.0])|\n",
      "|   (3,[2],[1.0])|             (1,[],[])|         (1,[0],[1.0])|    0|(5,[2,4],[1.0,1.0])|\n",
      "|   (3,[2],[1.0])|             (1,[],[])|         (1,[0],[1.0])|    0|(5,[2,4],[1.0,1.0])|\n",
      "|   (3,[2],[1.0])|             (1,[],[])|         (1,[0],[1.0])|    0|(5,[2,4],[1.0,1.0])|\n",
      "|   (3,[2],[1.0])|             (1,[],[])|         (1,[0],[1.0])|    0|(5,[2,4],[1.0,1.0])|\n",
      "+----------------+----------------------+----------------------+-----+-------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# assemble all feature columns into on single vector column\n",
    "from pyspark.ml.feature import VectorAssembler\n",
    "assembler = VectorAssembler(inputCols=['onehotencode_age', 'onehotencode_education', 'onehotencode_wantsMore'], outputCol='features')\n",
    "cuse_df_2 = assembler.transform(onehotencode_cuse).withColumnRenamed('y', 'label')\n",
    "cuse_df_2.show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 284,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+----------------------+----------------------+-----+---------+\n",
      "|onehotencode_age|onehotencode_education|onehotencode_wantsMore|label| features|\n",
      "+----------------+----------------------+----------------------+-----+---------+\n",
      "|       (3,[],[])|             (1,[],[])|             (1,[],[])|    0|(5,[],[])|\n",
      "|       (3,[],[])|             (1,[],[])|             (1,[],[])|    0|(5,[],[])|\n",
      "|       (3,[],[])|             (1,[],[])|             (1,[],[])|    0|(5,[],[])|\n",
      "|       (3,[],[])|             (1,[],[])|             (1,[],[])|    0|(5,[],[])|\n",
      "|       (3,[],[])|             (1,[],[])|             (1,[],[])|    0|(5,[],[])|\n",
      "+----------------+----------------------+----------------------+-----+---------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# split data into training and test datasets\n",
    "training, test = cuse_df_2.randomSplit([0.8, 0.2], seed=1234)\n",
    "training.show(5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 285,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "## ======= build cross validation model ===========\n",
    "\n",
    "# estimator\n",
    "from pyspark.ml.regression import GeneralizedLinearRegression\n",
    "glm = GeneralizedLinearRegression(featuresCol='features', labelCol='label', family='binomial')\n",
    "\n",
    "# parameter grid\n",
    "from pyspark.ml.tuning import ParamGridBuilder\n",
    "param_grid = ParamGridBuilder().\\\n",
    "    addGrid(glm.regParam, [0, 0.5, 1, 2, 4]).\\\n",
    "    build()\n",
    "    \n",
    "# evaluator\n",
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "evaluator = BinaryClassificationEvaluator(rawPredictionCol='prediction')\n",
    "\n",
    "# build cross-validation model\n",
    "from pyspark.ml.tuning import CrossValidator\n",
    "cv = CrossValidator(estimator=glm, estimatorParamMaps=param_grid, evaluator=evaluator, numFolds=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 294,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# fit model\n",
    "# cv_model = cv.fit(training)\n",
    "cv_model = cv.fit(cuse_df_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 302,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+----------------------+----------------------+-----+---------+------------------+\n",
      "|onehotencode_age|onehotencode_education|onehotencode_wantsMore|label| features|        prediction|\n",
      "+----------------+----------------------+----------------------+-----+---------+------------------+\n",
      "|       (3,[],[])|             (1,[],[])|             (1,[],[])|    0|(5,[],[])|0.5140024065151293|\n",
      "|       (3,[],[])|             (1,[],[])|             (1,[],[])|    0|(5,[],[])|0.5140024065151293|\n",
      "|       (3,[],[])|             (1,[],[])|             (1,[],[])|    0|(5,[],[])|0.5140024065151293|\n",
      "|       (3,[],[])|             (1,[],[])|             (1,[],[])|    0|(5,[],[])|0.5140024065151293|\n",
      "|       (3,[],[])|             (1,[],[])|             (1,[],[])|    0|(5,[],[])|0.5140024065151293|\n",
      "+----------------+----------------------+----------------------+-----+---------+------------------+\n",
      "only showing top 5 rows\n",
      "\n",
      "+----------------+----------------------+----------------------+-----+---------+------------------+\n",
      "|onehotencode_age|onehotencode_education|onehotencode_wantsMore|label|features |prediction        |\n",
      "+----------------+----------------------+----------------------+-----+---------+------------------+\n",
      "|(3,[],[])       |(1,[],[])             |(1,[],[])             |0    |(5,[],[])|0.5140024065151293|\n",
      "|(3,[],[])       |(1,[],[])             |(1,[],[])             |0    |(5,[],[])|0.5140024065151293|\n",
      "|(3,[],[])       |(1,[],[])             |(1,[],[])             |0    |(5,[],[])|0.5140024065151293|\n",
      "|(3,[],[])       |(1,[],[])             |(1,[],[])             |0    |(5,[],[])|0.5140024065151293|\n",
      "|(3,[],[])       |(1,[],[])             |(1,[],[])             |0    |(5,[],[])|0.5140024065151293|\n",
      "+----------------+----------------------+----------------------+-----+---------+------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# prediction\n",
    "pred_training_cv = cv_model.transform(training)\n",
    "pred_test_cv = cv_model.transform(test)\n",
    "\n",
    "pred_training_cv.show(5)\n",
    "pred_test_cv.show(5, truncate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 303,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DenseVector([-0.2806, -0.7999, -1.1892, 0.325, -0.833])"
      ]
     },
     "execution_count": 303,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cv_model.bestModel.coefficients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 304,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.056024275169240606"
      ]
     },
     "execution_count": 304,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cv_model.bestModel.intercept"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 305,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GeneralizedLinearRegression_46189bdb9cb6a9a002f6"
      ]
     },
     "execution_count": 305,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cv_model.bestModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 306,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6722119042705738"
      ]
     },
     "execution_count": 306,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluator.evaluate(pred_training_cv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 307,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6769880972917027"
      ]
     },
     "execution_count": 307,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluator.evaluate(pred_test_cv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 301,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GeneralizedLinearRegression_46189bdb9cb6a9a002f6"
      ]
     },
     "execution_count": 301,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cv_model.bestModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 320,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "pdf = pd.DataFrame({\n",
    "        'x1': ['a','a','b','c'],\n",
    "        'x2': ['apple', 'orange', 'orange', 'peach'],\n",
    "        'x3': [1, 1, 2, 4],\n",
    "        'x4': [2.4, 2.5, 3.5, 1.4],\n",
    "        'y1': [1, 0, 0, 1],\n",
    "        'y2': ['yes', 'no', 'no', 'yes']\n",
    "    })\n",
    "\n",
    "df = spark.createDataFrame(pdf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 321,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---+------+---+---+---+---+\n",
      "| x1|    x2| x3| x4| y1| y2|\n",
      "+---+------+---+---+---+---+\n",
      "|  a| apple|  1|2.4|  1|yes|\n",
      "|  a|orange|  1|2.5|  0| no|\n",
      "|  b|orange|  2|3.5|  0| no|\n",
      "|  c| peach|  4|1.4|  1|yes|\n",
      "+---+------+---+---+---+---+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [Root]",
   "language": "python",
   "name": "Python [Root]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
